/*
Copyright 2020 Caicloud Authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package executor

import (
	"context"
	"fmt"
	"io"
	"net/http"
	"path"
	"reflect"
	"runtime"
	"sort"

	"github.com/caicloud/nirvana/definition"
	"github.com/caicloud/nirvana/service"
)

// Executor executes a handler with a request context. It embeds
// MiddlewareExecutor (declared elsewhere in this package) and additionally
// exposes information about the content types it can consume and produce.
type Executor interface {
	MiddlewareExecutor

	// ContentTypeMap maps each consumable content type to the content types
	// that can be produced for it.
	ContentTypeMap() map[string][]string

	// Acceptable reports whether a request body of the given content type
	// can be consumed.
	Acceptable(contentType string) bool

	// Producible reports whether a response can be produced for at least one
	// of the given accept types.
	Producible(acceptTypes []string) bool
}

// DefinitionToExecutor generates an Executor for the Definition d served at
// urlPath.
//
// customCode, when non-zero, overrides the status code service.HTTPCodeFor
// would pick for d.Method. Every content type listed in d.Consumes,
// d.Produces and d.ErrorProduces must resolve to a registered consumer or
// producer; definition.MIMEAll additionally pulls in every registered one
// that is not already listed. d.Function must be a non-nil func whose
// parameters and results match d.Parameters and d.Results.
func DefinitionToExecutor(urlPath string, d definition.Definition, customCode int) (Executor, error) {
	var method string
	if d.Method == definition.Any {
		method = string(definition.Any)
	} else {
		method = service.HTTPMethodFor(d.Method)
	}
	if method == "" {
		return nil, DefinitionNoMethod.Error(d.Method, urlPath)
	}
	if len(d.Consumes) <= 0 {
		return nil, DefinitionNoConsumes.Error(d.Method, urlPath)
	}
	if len(d.Produces) <= 0 {
		return nil, DefinitionNoProduces.Error(d.Method, urlPath)
	}
	if len(d.ErrorProduces) <= 0 {
		return nil, DefinitionNoErrorProduces.Error(d.Method, urlPath)
	}
	if d.Function == nil {
		return nil, DefinitionNoFunction.Error(d.Method, urlPath)
	}
	value := reflect.ValueOf(d.Function)
	if value.Kind() != reflect.Func {
		return nil, DefinitionInvalidFunctionType.Error(value.Type(), d.Method, urlPath)
	}
	if value.IsNil() {
		// A typed nil func (e.g. a nil func variable) passes the interface
		// nil check above but has no code: runtime.FuncForPC would return nil
		// below (and FileLine would panic), and calling it at request time
		// would panic too.
		return nil, DefinitionNoFunction.Error(d.Method, urlPath)
	}
	if customCode == 0 {
		customCode = service.HTTPCodeFor(d.Method)
	}
	c := &executor{
		method:   method,
		code:     customCode,
		function: value,
	}

	// Resolve consumers; MIMEAll appends every registered consumer whose
	// content type was not listed explicitly.
	consumeAll := false
	consumes := map[string]bool{}
	for _, ct := range d.Consumes {
		if ct == definition.MIMEAll {
			consumeAll = true
			continue
		}
		consumer := service.ConsumerFor(ct)
		if consumer == nil {
			return nil, DefinitionNoConsumer.Error(ct, d.Method, urlPath)
		}
		c.consumers = append(c.consumers, consumer)
		consumes[consumer.ContentType()] = true
	}
	if consumeAll {
		for _, consumer := range service.AllConsumers() {
			if !consumes[consumer.ContentType()] {
				c.consumers = append(c.consumers, consumer)
			}
		}
	}

	// Resolve producers for regular results the same way.
	produceAll := false
	produces := map[string]bool{}
	for _, ct := range d.Produces {
		if ct == definition.MIMEAll {
			produceAll = true
			continue
		}
		producer := service.ProducerFor(ct)
		if producer == nil {
			return nil, DefinitionNoProducer.Error(ct, d.Method, urlPath)
		}
		c.producers = append(c.producers, producer)
		produces[producer.ContentType()] = true
	}
	if produceAll {
		for _, producer := range service.AllProducers() {
			if !produces[producer.ContentType()] {
				c.producers = append(c.producers, producer)
			}
		}
	}

	// Resolve producers used for errors the same way.
	errorProduceAll := false
	errorProduces := map[string]bool{}
	for _, ct := range d.ErrorProduces {
		if ct == definition.MIMEAll {
			errorProduceAll = true
			continue
		}
		producer := service.ProducerFor(ct)
		if producer == nil {
			return nil, DefinitionNoProducer.Error(ct, d.Method, urlPath)
		}
		c.errorProducers = append(c.errorProducers, producer)
		errorProduces[producer.ContentType()] = true
	}
	if errorProduceAll {
		for _, producer := range service.AllProducers() {
			if !errorProduces[producer.ContentType()] {
				c.errorProducers = append(c.errorProducers, producer)
			}
		}
	}

	// Describe the function for error messages, e.g.:
	//   1. Common function: api.CreateSomething(create.go#30)
	//   2. Anonymous function: api.glob..func1(create.go#30)
	// Anonymous function names are generated by Go; don't rely on their form.
	// Fall back to the function type if no runtime info is available.
	funcName := value.Type().String()
	if f := runtime.FuncForPC(value.Pointer()); f != nil {
		file, line := f.FileLine(value.Pointer())
		funcName = fmt.Sprintf("%s(%s#%d)", path.Base(f.Name()), path.Base(file), line)
	}
	ps, err := generateParameters(urlPath, funcName, value.Type(), d.Parameters)
	if err != nil {
		return nil, err
	}
	c.parameters = ps
	rs, err := generateResults(urlPath, funcName, value.Type(), d.Results)
	if err != nil {
		return nil, err
	}
	c.results = rs
	return c, nil
}

// generateParameters builds the runtime parameter descriptors for a handler
// function of type typ from its parameter definitions ps. path and funcName
// are only used in error messages. It fails when the number of definitions
// differs from the function's arity, when a parameter source has no
// registered generator, or when the generator or the operator chain rejects
// the parameter's type. Reported positions are 1-based.
func generateParameters(path, funcName string, typ reflect.Type, ps []definition.Parameter) ([]parameter, error) {
	if arity := typ.NumIn(); arity != len(ps) {
		return nil, DefinitionUnmatchedParameters.Error(funcName, arity, len(ps), path)
	}
	parameters := make([]parameter, 0, len(ps))
	for i, spec := range ps {
		gen := service.ParameterGeneratorFor(spec.Source)
		if gen == nil {
			return nil, service.NoParameterGenerator.Error(spec.Source)
		}

		argType := typ.In(i)
		// The generator produces whatever the first operator consumes; with
		// no operators it produces the function's argument type directly.
		target := argType
		if len(spec.Operators) > 0 {
			target = spec.Operators[0].In()
		}

		if err := gen.Validate(spec.Name, spec.Default, target); err != nil {
			return nil, InvalidParameter.Error(order(i+1), funcName, err.Error())
		}
		if len(spec.Operators) > 0 {
			if err := validateOperators(target, argType, spec.Operators); err != nil {
				return nil, InvalidOperatorsForParameter.Error(order(i+1), funcName, err.Error())
			}
		}

		parameters = append(parameters, parameter{
			name:         spec.Name,
			targetType:   target,
			defaultValue: spec.Default,
			generator:    gen,
			operators:    spec.Operators,
			optional:     spec.Optional,
		})
	}
	return parameters, nil
}

// generateResults builds the runtime result descriptors for a handler
// function of type typ from its result definitions rs. path and funcName are
// only used in error messages.
//
// Each result's destination handler validates the type it will actually
// receive: the function's return type, or the last operator's output type
// when operators transform the value first. The returned slice is ordered by
// ascending handler priority; results with equal priority keep their
// declaration order.
func generateResults(path, funcName string, typ reflect.Type, rs []definition.Result) ([]result, error) {
	if typ.NumOut() != len(rs) {
		return nil, DefinitionUnmatchedResults.Error(funcName, typ.NumOut(), len(rs), path)
	}
	results := make([]result, 0, len(rs))
	for index, r := range rs {
		handler := service.DestinationHandlerFor(r.Destination)
		if handler == nil {
			return nil, NoDestinationHandler.Error(r.Destination)
		}
		outType := typ.Out(index)
		if n := len(r.Operators); n > 0 {
			lastOutType := r.Operators[n-1].Out()
			if err := validateOperators(outType, lastOutType, r.Operators); err != nil {
				return nil, InvalidOperatorsForResult.Error(order(index+1), funcName, err.Error())
			}
			outType = lastOutType
		}
		if err := handler.Validate(outType); err != nil {
			// Positions are reported 1-based, hence index+1.
			return nil, InvalidResult.Error(order(index+1), funcName, err.Error())
		}
		results = append(results, result{
			index:     index,
			handler:   handler,
			operators: r.Operators,
		})
	}
	// sort.Sort is not guaranteed to be stable; sort.Stable keeps results that
	// share a priority in declaration order so handling order is deterministic.
	sort.Stable(resultsSorter(results))
	return results, nil
}

// validateOperators checks that operators form a type-compatible chain from
// in to out:
//
//	in                     -> operators[0].In()
//	operators[k].Out()     -> operators[k+1].In()
//	operators[len-1].Out() -> out
//
// where "a -> b" means a must be assignable to b. An empty chain is always
// valid. Operator positions in errors are 1-based.
func validateOperators(in, out reflect.Type, operators []definition.Operator) error {
	if len(operators) == 0 {
		return nil
	}
	current := in
	for i, op := range operators {
		if !current.AssignableTo(op.In()) {
			// The output of the previous step cannot feed operator i.
			return invalidOperatorInType.Error(current, order(i+1))
		}
		current = op.Out()
	}
	if !current.AssignableTo(out) {
		// The last operator cannot produce the required out type.
		return invalidOperatorOutType.Error(order(len(operators)), out)
	}
	return nil
}

type (
	// executor is the Executor built by DefinitionToExecutor for a single
	// definition.
	executor struct {
		// method is the HTTP method, or string(definition.Any).
		method string

		// code is the default status code; Execute derives one from the
		// request method when it is 0.
		code int

		// consumers are handed to parameter generators to decode input.
		consumers []service.Consumer

		// producers render results not destined for definition.Error.
		producers []service.Producer

		// errorProducers render errors and definition.Error results.
		errorProducers []service.Producer

		// parameters describe each function argument, in argument order.
		parameters []parameter

		// results describe each return value, sorted by handler priority.
		results []result

		// function is the handler function invoked via reflection.
		function reflect.Value
	}

	// parameter describes how to produce one argument of the handler.
	parameter struct {
		// name is the parameter name passed to the generator and operators.
		name string

		// targetType is the type the generator must produce: the function's
		// argument type, or the first operator's input type.
		targetType reflect.Type

		// defaultValue is used when the generator yields nil.
		defaultValue interface{}

		// generator produces the raw value from the request.
		generator service.ParameterGenerator

		// operators transform the generated value in order.
		operators []definition.Operator

		// optional allows the final value to be nil.
		optional bool
	}

	// result describes how to handle one return value of the handler.
	result struct {
		// index is the position of the value among the function's results.
		index int

		// handler delivers the value to its destination.
		handler service.DestinationHandler

		// operators transform the value in order before handling.
		operators []definition.Operator
	}
)

// resultsSorter implements sort.Interface, ordering results by the priority
// of their destination handlers, lowest value first.
type resultsSorter []result

// Len returns the number of results.
func (s resultsSorter) Len() int { return len(s) }

// Less reports whether result i has a lower handler priority than result j.
func (s resultsSorter) Less(i, j int) bool {
	left, right := s[i].handler.Priority(), s[j].handler.Priority()
	return left < right
}

// Swap exchanges results i and j.
func (s resultsSorter) Swap(i, j int) { s[j], s[i] = s[i], s[j] }

// check reports whether any of the content types in ats is handled by one of
// producers. The receiver's state is not consulted.
func (e *executor) check(producers []service.Producer, ats []string) bool {
	for _, p := range producers {
		produced := p.ContentType()
		for _, at := range ats {
			if at == produced {
				return true
			}
		}
	}
	return false
}

// Acceptable reports whether one of the executor's consumers handles the
// content type ct.
func (e *executor) Acceptable(ct string) bool {
	for i := range e.consumers {
		if e.consumers[i].ContentType() == ct {
			return true
		}
	}
	return false
}

// Producible reports whether at least one type in ats can be produced both
// by the regular producers and by the error producers, so that either a
// result or an error can be rendered.
func (e *executor) Producible(ats []string) bool {
	if !e.check(e.producers, ats) {
		return false
	}
	return e.check(e.errorProducers, ats)
}

// ContentTypeMap returns, for every consumer content type, the producer
// content types in producer order. The map is empty when the executor has
// no producers.
func (e *executor) ContentTypeMap() map[string][]string {
	mapping := make(map[string][]string, len(e.consumers))
	for _, consumer := range e.consumers {
		in := consumer.ContentType()
		for _, producer := range e.producers {
			mapping[in] = append(mapping[in], producer.ContentType())
		}
	}
	return mapping
}

// Execute executes with context.
//
// For every parameter it asks the generator for a value. If the generator
// yields nil, it falls back to the default value, then to the zero value of
// the generator's target type. The value then passes through the parameter's
// operators and must be non-nil unless the parameter is optional. Generation
// failures are written with the error producers. The function is then called.
// Each result passes through its operators and is given to its destination
// handler, in e.results order, until a handler asks to stop. Finally the
// status code is written if the header is still writable. Parameter and
// result values implementing io.Closer are closed when Execute returns; a
// Close error becomes the returned error only if no other error occurred.
func (e *executor) Execute(ctx context.Context) (err error) {
	c := service.HTTPContextFrom(ctx)
	if c == nil {
		return service.NoContext.Error()
	}
	fnType := e.function.Type()
	paramValues := make([]reflect.Value, 0, len(e.parameters))
	for i, p := range e.parameters {
		// Use a dedicated error variable: declaring `err` with := here would
		// shadow the named return, and the deferred Close below would then
		// assign to the shadow and silently drop Close errors.
		value, paramErr := p.generator.Generate(ctx, c.ValueContainer(), e.consumers, p.name, p.targetType)
		if paramErr != nil {
			return service.WriteError(ctx, e.errorProducers, paramErr)
		}
		if value == nil {
			if p.defaultValue != nil {
				value = p.defaultValue
			} else {
				value = reflect.Zero(p.targetType).Interface()
			}
		}
		for _, operator := range p.operators {
			value, paramErr = operator.Operate(ctx, p.name, value)
			if paramErr != nil {
				return service.WriteError(ctx, e.errorProducers, paramErr)
			}
		}

		if value == nil && !p.optional {
			return service.WriteError(ctx, e.errorProducers, requiredField.Error(p.name, p.generator.Source()))
		}

		if closer, ok := value.(io.Closer); ok {
			defer func() {
				if closeErr := closer.Close(); closeErr != nil && err == nil {
					err = closeErr
				}
			}()
		}

		if value == nil {
			// p.targetType is the first operator's input type when operators
			// exist, so the zero value must come from the function's own
			// argument type or Call would panic on a type mismatch.
			paramValues = append(paramValues, reflect.New(fnType.In(i)).Elem())
		} else {
			paramValues = append(paramValues, reflect.ValueOf(value))
		}
	}

	code := e.code
	if code == 0 {
		switch c.Request().Method {
		case http.MethodPost:
			code = http.StatusCreated
		case http.MethodDelete:
			code = http.StatusNoContent
		default:
			code = http.StatusOK
		}
	}

	resultValues := e.function.Call(paramValues)
	for _, r := range e.results {
		data := resultValues[r.index].Interface()
		for _, operator := range r.operators {
			newData, opErr := operator.Operate(ctx, string(r.handler.Destination()), data)
			if opErr != nil {
				return opErr
			}
			data = newData
		}
		// A nil interface never satisfies the assertion, so no nil check is needed.
		if closer, ok := data.(io.Closer); ok {
			defer func() {
				if closeErr := closer.Close(); closeErr != nil && err == nil {
					err = closeErr
				}
			}()
		}
		producers := e.producers
		if r.handler.Destination() == definition.Error {
			// Select correct producers to produce error.
			producers = e.errorProducers
		}
		goon, handleErr := r.handler.Handle(ctx, producers, code, data)
		if handleErr != nil {
			return handleErr
		}
		if !goon {
			break
		}
	}
	resp := c.ResponseWriter()
	if resp.HeaderWritable() {
		resp.WriteHeader(code)
	}
	return nil
}

// order formats a 1-based position as an English ordinal: "1st", "2nd",
// "3rd", "4th", and so on. Numbers ending in 11, 12 or 13 take "th"
// ("11th", "112th") even though they end in 1, 2 or 3.
func order(i int) string {
	suffix := "th"
	if tens := i % 100; tens < 11 || tens > 13 {
		switch i % 10 {
		case 1:
			suffix = "st"
		case 2:
			suffix = "nd"
		case 3:
			suffix = "rd"
		}
	}
	return fmt.Sprintf("%d%s", i, suffix)
}
